require(ggplot2)
require(reshape2)
require(viridis)

# Shared grid for both plot axes: delta (x axis, labelled n / p) and rho
# (y axis, labelled k / n) take the same equally spaced values on [0.05, 1].
# NOTE(review): `length.out` (the number of grid points) is not assigned
# anywhere in the visible script, so it has to exist in the calling
# environment. Fail early with a clear message instead of the opaque
# "object 'length.out' not found". TODO confirm the intended grid size.
if (!exists("length.out")) {
  stop("`length.out` (number of grid points) must be defined before ",
       "running this script", call. = FALSE)
}
delta <- rho <- seq(0.05, 1, length.out = length.out)

# Read one saved result object per grid index; the i-th file becomes the i-th
# element of `sim_res`. lapply() + seq_along() avoids growing a list inside a
# loop and the 1:length() pitfall when the grid is empty.
sim_res <- lapply(seq_along(delta), function(i) {
  readRDS(paste0("./nochange/corr00/res_pred/sim_pred_mse_corr00_", i, ".rds"))
})

## data prep for plot
# Every element of `sim_res` carries one column per method (1 = Lasso,
# 2 = Bayes Bridge, 3 = Ridge, 4 = ElasticNet). Binding column `j` across all
# elements yields a matrix with one column per delta grid point.
collect_method <- function(j) {
  do.call("cbind", lapply(sim_res, function(res) res[, j]))
}

# Convert one method's matrix into a wide data frame: columns are named after
# the delta grid, and each row is tagged with its rho value and the method's
# display label.
as_wide_frame <- function(mat, label) {
  wide <- data.frame(mat)
  colnames(wide) <- paste0("delta", delta)
  wide$rho <- rho
  wide$method <- label
  wide
}

mse_lasso  <- collect_method(1)
mse_bridge <- collect_method(2)
mse_ridge  <- collect_method(3)
mse_en     <- collect_method(4)

mse_lasso_dat  <- as_wide_frame(mse_lasso, "Lasso")
mse_bridge_dat <- as_wide_frame(mse_bridge, "Bayes Bridge")
mse_ridge_dat  <- as_wide_frame(mse_ridge, "Ridge")
mse_en_dat     <- as_wide_frame(mse_en, "ElasticNet")

# Stack the four methods and reshape to long format: one row per
# (rho, method, delta) combination with the loss stored in column `L2`.
mse_full <- melt(
  rbind(mse_lasso_dat, mse_bridge_dat, mse_ridge_dat, mse_en_dat),
  id.vars = c("rho", "method"),
  variable.name = "delta",
  value.name = "L2"
)

# melt() stores `delta` as a factor of the former column names, in grid order;
# use the level index to recover the numeric grid value.
mse_full$delta <- delta[as.integer(mse_full$delta)]

# Put "Bayes Bridge" first among the factor levels (and hence the facets).
mse_full$method <- relevel(factor(mse_full$method), ref = "Bayes Bridge")


## theoretical boundary
# Evaluate the boundary curve at `rho` (vectorised over `rho`).
# Returns exp(1 - 1 / rho - log(rho)); the exponent is formed first and then
# exponentiated.
boundary <- function(rho) {
  log_value <- 1 - 1 / rho - log(rho)
  exp(log_value)
}
# Boundary value at every rho grid point. Only the points whose boundary value
# exceeds 0.05 (the lower end of the delta grid) are kept; which() drops NA
# comparisons, just as subset() would.
delta_hat <- boundary(rho)
delta_line <- data.frame(rho, delta_hat)
delta_line <- delta_line[which(delta_line$delta_hat > 0.05), ]

# Copy of the loss rounded to 3 decimals; this is the column mapped to the
# tile fill in the plot.
mse_full[["L21"]] <- round(mse_full[["L2"]], digits = 3)
# fig_save <- ggplot(data = subset(mse_full, method != "Ridge"), aes(x = delta, y = rho)) + geom_tile(aes(fill = L2)) +
#     scale_fill_distiller(palette = "YlGnBu", direction = 1) +
#     # scale_fill_viridis(option = "magma", direction = -1) +
#     labs(x = expression(paste(alpha, " = ", N / p)), y  = expression(paste(rho, " = ", k / N)), fill = "") +
#     scale_x_continuous(expand = c(0,0)) +
#     scale_y_continuous(expand = c(0,0)) +
#     facet_wrap(~method) +
#     theme_bw() +
#     theme(panel.background = element_blank(),
#             panel.grid = element_blank(),
#             panel.border = element_rect(color = 'gray30', size = 0.7),
#             strip.background = element_blank(),
#             strip.text = element_text(face = 'bold', size = 12),
#             axis.text.x = element_text(size = 11, color = 'gray30'),
#             axis.text.y = element_text(size = 11, color = 'gray30'),
#             axis.title.x = element_text(size = 12.5),
#             axis.title.y = element_text(size = 12.5))

# fig_save <- ggplot(mse_full, aes(x = delta, y = rho)) + geom_tile(aes(fill = L21)) +
#     scale_fill_distiller(palette = "YlGnBu", direction = 1) +
#     # scale_fill_viridis(option = "magma", direction = -1) +
#     labs(x = expression(paste(delta, " = ", N / p)), y  = expression(paste(rho, " = ", k / N)), fill = "") +
#     scale_x_continuous(expand = c(0,0)) +
#     scale_y_continuous(expand = c(0,0)) +
#     facet_wrap(~method) +
#     theme_bw() +
#     theme(panel.background = element_blank(),
#             panel.grid = element_blank(),
#             panel.border = element_rect(color = 'gray30', size = 0.7),
#             strip.background = element_blank(),
#             strip.text = element_text(face = 'bold', size = 12),
#             axis.text.x = element_text(size = 11, color = 'gray30'),
#             axis.text.y = element_text(size = 11, color = 'gray30'),
#             axis.title.x = element_text(size = 12.5),
#             axis.title.y = element_text(size = 12.5))
#
# ggsave(fig_save, file = 'L2_pred_error_corr0.pdf', height = 6.5, width = 7)

# Heat map of the rounded prediction loss (`L21`) over the (delta, rho) grid,
# with one facet per method laid out in a single row of four panels.
fig_save <- ggplot(mse_full, aes(x = delta, y = rho)) +
  geom_tile(aes(fill = L21)) +
  scale_fill_distiller(palette = "RdYlBu") +
  # scale_fill_viridis(option = "magma", direction = -1) +
  labs(
    x = expression(paste(delta, " = ", n / p)),
    y = expression(paste(rho, " = ", k / n)),
    # NOTE(review): the legend title renders as "x 10^5"; confirm it matches
    # the scale of the values stored in `L21`.
    fill = expression(paste("x", 10^5))
  ) +
  # No padding around the tiles so they fill each panel completely.
  scale_x_continuous(expand = c(0, 0)) +
  scale_y_continuous(expand = c(0, 0)) +
  facet_wrap(~method, ncol = 4) +
  ggtitle("Panel A (Prediction Loss)") +
  theme_bw() +
  theme(
    panel.background = element_blank(),
    panel.grid = element_blank(),
    # NOTE(review): ggplot2 >= 3.4.0 deprecates `size` for rect/line elements
    # in favour of `linewidth`; `size` is kept so older ggplot2 still works.
    panel.border = element_rect(color = 'gray30', size = 0.7),
    plot.title = element_text(face = 'bold'),
    strip.background = element_blank(),
    strip.text = element_text(face = 'bold', size = 12),
    axis.text.x = element_text(size = 11, color = 'gray30'),
    axis.text.y = element_text(size = 11, color = 'gray30'),
    axis.title.x = element_text(size = 12.5),
    axis.title.y = element_text(size = 12.5)
  )

# Write the figure to disk. The arguments are named explicitly rather than
# relying on partial matching of `file` to `filename`, and the output
# directory is created first so the save does not fail on a fresh checkout.
fig_path <- 'SI/Figure5_Panel_A.pdf'
dir.create(dirname(fig_path), showWarnings = FALSE, recursive = TRUE)
ggsave(filename = fig_path, plot = fig_save, height = 2.5, width = 7.3)
